{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NBVAL_IGNORE_OUTPUT\n",
    "import os\n",
    "\n",
    "# Janky code to do different setup when run in a Colab notebook vs VSCode\n",
    "DEVELOPMENT_MODE = False\n",
    "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n",
    "IN_GITHUB = True\n",
    "try:\n",
    "    import google.colab\n",
    "\n",
    "    IN_COLAB = True\n",
    "    print(\"Running as a Colab notebook\")\n",
    "\n",
    "    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n",
    "    # # Install another version of node that makes PySvelte work way faster\n",
    "    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n",
    "    # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n",
    "except:\n",
    "    IN_COLAB = False\n",
    "\n",
    "if not IN_GITHUB and not IN_COLAB:\n",
    "    print(\"Running as a Jupyter notebook - intended for development only!\")\n",
    "    from IPython import get_ipython\n",
    "\n",
    "    ipython = get_ipython()\n",
    "    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n",
    "    ipython.magic(\"load_ext autoreload\")\n",
    "    ipython.magic(\"autoreload 2\")\n",
    "\n",
    "if IN_GITHUB or IN_COLAB:\n",
    "    %pip install torch\n",
    "    %pip install git+https://github.com/TransformerLensOrg/TransformerLens.git@dev\n",
    "    \n",
    "from transformer_lens import HookedTransformer, HookedTransformerConfig\n",
    "import torch as t\n",
    "\n",
    "device = t.device(\"cuda\" if t.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model gpt2-small into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "# NBVAL_IGNORE_OUTPUT\n",
    "\n",
    "\n",
    "reference_gpt2 = HookedTransformer.from_pretrained(\n",
    "    \"gpt2-small\",\n",
    "    fold_ln=False,\n",
    "    center_unembed=False,\n",
    "    center_writing_weights=False,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "# [1.1] Transformer From Scratch\n",
    "# 1️⃣ UNDERSTANDING INPUTS & OUTPUTS OF A TRANSFORMER\n",
    "\n",
    "sorted_vocab = sorted(list(reference_gpt2.tokenizer.vocab.items()), key=lambda n: n[1])\n",
    "first_vocab = sorted_vocab[0]\n",
    "assert isinstance(first_vocab, tuple)\n",
    "assert isinstance(first_vocab[0], str)\n",
    "first_vocab[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['<|endoftext|>', 'R', 'alph']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reference_gpt2.to_str_tokens(\"Ralph\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['<|endoftext|>', ' Ralph']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reference_gpt2.to_str_tokens(\" Ralph\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['<|endoftext|>', ' r', 'alph']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "reference_gpt2.to_str_tokens(\" ralph\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['<|endoftext|>', 'ral', 'ph']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reference_gpt2.to_str_tokens(\"ralph\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 35])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "reference_text = \"I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!\"\n",
    "tokens = reference_gpt2.to_tokens(reference_text)\n",
    "tokens.shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 35, 50257])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "logits, cache = reference_gpt2.run_with_cache(tokens, device=device)\n",
    "logits.shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' I'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "most_likely_next_tokens = reference_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1)[0])\n",
    "most_likely_next_tokens[-1]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('hook_embed', (1, 35, 768)),\n",
       " ('hook_pos_embed', (1, 35, 768)),\n",
       " ('ln_final.hook_normalized', (1, 35, 768)),\n",
       " ('ln_final.hook_scale', (1, 35, 1))]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 2️⃣ CLEAN TRANSFORMER IMPLEMENTATION\n",
    "\n",
    "layer_0_hooks = [\n",
    "    (name, tuple(tensor.shape)) for name, tensor in cache.items() if \".0.\" in name\n",
    "]\n",
    "non_layer_hooks = [\n",
    "    (name, tuple(tensor.shape)) for name, tensor in cache.items() if \"blocks\" not in name\n",
    "]\n",
    "\n",
    "sorted(non_layer_hooks, key=lambda x: x[0])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('blocks.0.attn.hook_attn_scores', (1, 12, 35, 35)),\n",
       " ('blocks.0.attn.hook_k', (1, 35, 12, 64)),\n",
       " ('blocks.0.attn.hook_pattern', (1, 12, 35, 35)),\n",
       " ('blocks.0.attn.hook_q', (1, 35, 12, 64)),\n",
       " ('blocks.0.attn.hook_v', (1, 35, 12, 64)),\n",
       " ('blocks.0.attn.hook_z', (1, 35, 12, 64)),\n",
       " ('blocks.0.hook_attn_out', (1, 35, 768)),\n",
       " ('blocks.0.hook_mlp_out', (1, 35, 768)),\n",
       " ('blocks.0.hook_resid_mid', (1, 35, 768)),\n",
       " ('blocks.0.hook_resid_post', (1, 35, 768)),\n",
       " ('blocks.0.hook_resid_pre', (1, 35, 768)),\n",
       " ('blocks.0.ln1.hook_normalized', (1, 35, 768)),\n",
       " ('blocks.0.ln1.hook_scale', (1, 35, 1)),\n",
       " ('blocks.0.ln2.hook_normalized', (1, 35, 768)),\n",
       " ('blocks.0.ln2.hook_scale', (1, 35, 1)),\n",
       " ('blocks.0.mlp.hook_post', (1, 35, 3072)),\n",
       " ('blocks.0.mlp.hook_pre', (1, 35, 3072))]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "sorted(layer_0_hooks, key=lambda x: x[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "# NBVAL_IGNORE_OUTPUT\n",
    "# [1.2] Intro to mech interp\n",
    "# 2️⃣ FINDING INDUCTION HEADS\n",
    "\n",
    "cfg = HookedTransformerConfig(\n",
    "    d_model=768,\n",
    "    d_head=64,\n",
    "    n_heads=12,\n",
    "    n_layers=2,\n",
    "    n_ctx=2048,\n",
    "    d_vocab=50278,\n",
    "    attention_dir=\"causal\",\n",
    "    attn_only=True, # defaults to False\n",
    "    tokenizer_name=\"EleutherAI/gpt-neox-20b\", \n",
    "    seed=398,\n",
    "    use_attn_result=True,\n",
    "    normalization_type=None, # defaults to \"LN\", i.e. layernorm with weights & biases\n",
    "    positional_embedding_type=\"shortformer\"\n",
    ")\n",
    "model = HookedTransformer(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 62, 50278])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "\n",
    "text = \"We think that powerful, significantly superhuman machine intelligence is more likely than not to be created this century. If current machine learning techniques were scaled up to this level, we think they would by default produce systems that are deceptive or manipulative, and that no solid plans are known for how to avoid this.\"\n",
    "\n",
    "logits, cache = model.run_with_cache(text, remove_batch_dim=True)\n",
    "\n",
    "logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cache[\"embed\"].ndim"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
